import cv2
import random
import numpy as np
import onnxruntime as ort
import argparse
from pathlib import Path

# Run inference on the GPU when possible; set this to False if CUDA setup errors out.
cuda = True

# ONNX model file to load. yolov7-tiny.onnx is the small variant, suggested
# when the frame rate is too low.
w = "yolov7-tiny.onnx"

# Execution providers in order of preference: CUDA first (only when enabled),
# CPU always available as the last entry.
providers = ['CPUExecutionProvider']
if cuda:
    providers.insert(0, 'CUDAExecutionProvider')
session = ort.InferenceSession(w, providers=providers)

# Class labels, indexed by the integer class id the model outputs
# (presumably the COCO label set the model was trained on — confirm per model).
names = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
         'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
         'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
         'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
         'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
         'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
         'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
         'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
         'hair drier', 'toothbrush']

# One random 3-channel colour per label so different classes are easy to tell
# apart when drawn.
colors = {name: [random.randint(0, 255) for _ in range(3)] for name in names}

# Aspect-ratio-preserving resize with constant-colour padding ("letterbox").
def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleup=True, stride=32):
    """Resize ``im`` keeping its aspect ratio, then pad it towards ``new_shape``.

    Args:
        im: image array whose first two axes are (height, width).
        new_shape: target (height, width), or one int for a square target.
        color: constant fill value used for the padded border.
        auto: when true, pad only up to the next multiple of ``stride`` rather
            than all the way to ``new_shape``.
        scaleup: when false, the image is only ever shrunk, never enlarged.
        stride: alignment used when ``auto`` is true.

    Returns:
        ``(padded, ratio, (dw, dh))``: the padded image, the scale factor that
        was applied to ``im``, and the horizontal/vertical padding per side
        (half of the total padding, so it may be fractional).
    """
    src_h, src_w = im.shape[:2]

    target = (new_shape, new_shape) if isinstance(new_shape, int) else new_shape
    target_h, target_w = target[0], target[1]

    # One common scale so the whole image fits inside the target.
    ratio = min(target_h / src_h, target_w / src_w)
    if not scaleup:
        ratio = min(ratio, 1.0)

    scaled_w = int(round(src_w * ratio))
    scaled_h = int(round(src_h * ratio))

    # Total padding still needed along each axis.
    pad_w = target_w - scaled_w
    pad_h = target_h - scaled_h
    if auto:
        pad_w, pad_h = np.mod(pad_w, stride), np.mod(pad_h, stride)

    # Split evenly between the two sides.
    half_w = pad_w / 2
    half_h = pad_h / 2

    if (src_w, src_h) != (scaled_w, scaled_h):
        im = cv2.resize(im, (scaled_w, scaled_h), interpolation=cv2.INTER_LINEAR)

    # The -/+0.1 nudges make an odd total split as floor on top/left and
    # ceil on bottom/right, so the two sides always add up to the total.
    top = int(round(half_h - 0.1))
    bottom = int(round(half_h + 0.1))
    left = int(round(half_w - 0.1))
    right = int(round(half_w + 0.1))

    padded = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)
    return padded, ratio, (half_w, half_h)

def _preprocess(frame):
    """Turn a BGR frame into the model input tensor.

    Returns ``(tensor, ratio, dwdh)``:
    - ``tensor`` is a contiguous float32 array of shape (1, 3, 640, 640) in [0, 1].
    - ``ratio`` and ``dwdh`` are the letterbox scale and per-side padding needed
      to map predicted boxes back onto ``frame``.
    """
    image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    image, ratio, dwdh = letterbox(image, auto=False)
    # HWC -> CHW, then add a leading batch axis.
    image = np.expand_dims(image.transpose((2, 0, 1)), 0)
    tensor = np.ascontiguousarray(image).astype(np.float32)
    tensor /= 255
    return tensor, ratio, dwdh


def _to_frame_box(x0, y0, x1, y1, ratio, dwdh):
    """Undo letterbox padding/scaling; return integer frame pixels ``[x0, y0, x1, y1]``."""
    box = np.array([x0, y0, x1, y1])
    box -= np.array(dwdh * 2)  # dwdh * 2 == (dw, dh, dw, dh)
    box /= ratio
    return box.round().astype(np.int32).tolist()


def _expanded_crop(frame, box, margin):
    """Return ``frame`` cut to ``box`` grown by ``margin`` of its width/height per side.

    The region is clamped to the frame. Without clamping, a box touching the
    border gives a negative slice start, which numpy counts from the end,
    producing an empty or wrong crop.
    """
    x0, y0, x1, y1 = box
    pad_x = int((x1 - x0) * margin)
    pad_y = int((y1 - y0) * margin)
    height, width = frame.shape[:2]
    top, bottom = max(y0 - pad_y, 0), min(y1 + pad_y, height)
    left, right = max(x0 - pad_x, 0), min(x1 + pad_x, width)
    return frame[top:bottom, left:right]


def detect(Class, Save_path, Video_path, conf_thres=0.75, interval=6, margin=0.05):
    """Run the detector over a video and save enlarged crops of one class.

    Every ``interval``-th frame goes through the model. Each detection labelled
    ``Class`` whose score (rounded to 3 decimals) is >= ``conf_thres`` is
    cropped from the original frame and written as a PNG to ``Save_path``.

    Crops are named ``<frame_index:06d>.png``. Additional detections in the
    same frame get ``_1``, ``_2``, ... suffixes instead of overwriting the
    first. Every frame is shown in a "results" window; press ``q`` to stop early.

    Args:
        Class: label to keep; must be one of ``names``.
        Save_path: output directory; created if it does not exist.
        Video_path: source given to ``cv2.VideoCapture``, e.g. a video file
            or stream URL.
        conf_thres: minimum detection score to save.
        interval: only frames whose position is a multiple of this are processed.
        margin: fraction of the box width/height added on each side of the crop.

    Raises:
        ValueError: if ``Class`` is not a known label or ``interval`` < 1.
        OSError: if the video source cannot be opened.
    """
    if Class not in names:
        raise ValueError(f'Unknown class {Class!r}; it must be one of the model labels')
    if interval < 1:
        raise ValueError(f'interval must be >= 1, got {interval}')

    save_dir = Path(Save_path)
    # cv2.imwrite does not create directories; it just returns False.
    save_dir.mkdir(parents=True, exist_ok=True)

    # Model I/O names are fixed, so look them up once instead of per frame.
    input_name = session.get_inputs()[0].name
    output_names = [o.name for o in session.get_outputs()]

    cap = cv2.VideoCapture(Video_path)
    try:
        if not cap.isOpened():
            raise OSError(f'Cannot open video source: {Video_path}')

        while True:
            ret, frame = cap.read()
            if not ret:
                # End of stream or read failure. Retrying here would spin forever.
                break

            # POS_FRAMES is the index of the *next* frame, so after read()
            # this is the 1-based position of the frame just read.
            frame_index = int(cap.get(cv2.CAP_PROP_POS_FRAMES))

            # Only sampled frames use the results, so only they pay for inference.
            if frame_index % interval == 0:
                tensor, ratio, dwdh = _preprocess(frame)
                outputs = session.run(output_names, {input_name: tensor})[0]

                kept = 0
                # Each row: (batch_id, x0, y0, x1, y1, cls_id, score), box in letterboxed space.
                for _batch_id, x0, y0, x1, y1, cls_id, score in outputs:
                    if names[int(cls_id)] != Class or round(float(score), 3) < conf_thres:
                        continue
                    box = _to_frame_box(x0, y0, x1, y1, ratio, dwdh)
                    crop = _expanded_crop(frame, box, margin)
                    if crop.size == 0:
                        continue  # box lies entirely outside the frame

                    suffix = f'_{kept}' if kept else ''
                    kept += 1
                    out_path = save_dir / f'{frame_index:06d}{suffix}.png'
                    try:
                        written = cv2.imwrite(str(out_path), crop)
                    except cv2.error as e:
                        print(f'发生了一个异常： {e}')
                        continue
                    if not written:
                        print(f'Failed to write {out_path}')

            cv2.imshow("results", frame)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Release the capture and close windows even if something above raised.
        cap.release()
        cv2.destroyAllWindows()

if __name__ == '__main__':
    # Command-line entry: pick the class to crop, where to save it, and the source video.
    parser = argparse.ArgumentParser()
    parser.add_argument('--Class', type=str, default='car', help='The class to save')
    parser.add_argument('--Save_path', type=str, default=r'save', help='initial save path')
    # Build the default with pathlib so the separator matches the host OS.
    # The old hard-coded backslash path does not resolve on Linux/macOS; on
    # Windows this still yields 'video\IMG_5459.MP4'.
    parser.add_argument('--Video_path', type=str, default=str(Path('video') / 'IMG_5459.MP4'),
                        help='initial Video path')
    opt = parser.parse_args()
    detect(opt.Class, opt.Save_path, opt.Video_path)
